{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "import glob\n",
    "import os\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GPU setup: make at most the first three physical GPUs visible to TensorFlow.\n",
    "# NOTE(review): TensorFlow documents that visible devices cannot be changed once\n",
    "# the runtime has been initialized, so run this cell before any other TF work.\n",
    "gpus = tf.config.experimental.list_physical_devices( 'GPU' )\n",
    "tf.config.experimental.set_visible_devices( gpus[:3], 'GPU' )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocessing for the test data:\n",
    "def preprocessing_test(path):\n",
    "    \"\"\"Read one test image from disk and return it as a normalized tensor.\n",
    "\n",
    "    Args:\n",
    "        path: Location of the image file, passed straight to tf.io.read_file.\n",
    "\n",
    "    Returns:\n",
    "        A float32 tensor of shape (256, 256, 3) whose values were divided by 255.\n",
    "    \"\"\"\n",
    "    raw_bytes = tf.io.read_file( path )\n",
    "    # Decode as JPEG and force three channels (channels=0 would keep the file's own count).\n",
    "    pixels = tf.image.decode_jpeg( raw_bytes, channels = 3 )\n",
    "    # Resize the whole frame to exactly 256x256 instead of cropping.\n",
    "    # NOTE(review): with the default arguments tf.image.resize does not keep the\n",
    "    # aspect ratio, so non-square inputs are stretched -- confirm this is intended.\n",
    "    pixels = tf.image.resize( pixels, [256,256] )\n",
    "    # Make the dtype explicit before scaling.\n",
    "    pixels = tf.cast( pixels, tf.float32 )\n",
    "    # Normalize by dividing by 255; the test set carries no label, so only the image is returned.\n",
    "    return pixels / 255"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the trained model:\n",
    "model_new = tf.keras.models.load_model( '../model_h5/model_large.h5' )\n",
    "\n",
    "# Collect the test image paths.\n",
    "# They are sorted instead of shuffled: the predictions are later written out\n",
    "# without file names, so a fixed, reproducible order is the only way to map\n",
    "# each output line back to its image (an unseeded shuffle made that impossible).\n",
    "test_path = sorted( glob.glob('../test_new1/*') )\n",
    "if not test_path:\n",
    "    # Fail early with a clear message instead of an obscure error inside the tf.data pipeline.\n",
    "    raise FileNotFoundError( 'no test files matched ../test_new1/*' )\n",
    "\n",
    "# Build the dataset from the paths only (the test set has no labels):\n",
    "test_dataset = tf.data.Dataset.from_tensor_slices( test_path )\n",
    "# Decode the images in parallel, then add the batch dimension:\n",
    "AUTOTUNE = tf.data.experimental.AUTOTUNE   # let tf.data choose the degree of parallelism\n",
    "BATCH_SIZE = 16\n",
    "test_dataset = test_dataset.map( preprocessing_test, num_parallel_calls = AUTOTUNE )\n",
    "test_dataset = test_dataset.cache().batch(BATCH_SIZE)\n",
    "\n",
    "# Run inference:\n",
    "prediction = model_new.predict(test_dataset)\n",
    "\n",
    "# Flatten the per-sample rows into one flat list of scores.\n",
    "# One-level flatten in linear time; sum(rows, []) rebuilds the accumulator on every step.\n",
    "prediction_tmp = [ score for row in prediction.tolist() for score in row ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map every score to a +1 / -1 label.\n",
    "# Flatten the prediction rows by one level in linear time\n",
    "# (sum(rows, []) is quadratic in the number of rows):\n",
    "prediction_tmp = [ score for row in prediction.tolist() for score in row ]\n",
    "\n",
    "# Labels stay wrapped in one-element lists, so str() of an entry is '[1]' or '[-1]'.\n",
    "label1 = [1]; label2 = [-1]\n",
    "# Scores strictly above 0.5 get label1, everything else gets label2.\n",
    "label_predict = [ label1 if score > 0.5 else label2 for score in prediction_tmp ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the labels to disk, one str()-formatted entry per line:\n",
    "with open('prediction.txt', 'w') as out_file:\n",
    "    out_file.writelines( str( label ) + '\\n' for label in label_predict )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
